# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper routines for quantization."""

from typing import Any

import chex
import jax.numpy as jnp
from flax import struct


# pylint:disable=no-value-for-parameter
@struct.dataclass
class QuantizedValue:
    """State associated with quantized value."""

    quantized: chex.Array
    diagonal: chex.Array  # Diagonal (if extract_diagonal is set)
    bucket_size: chex.Array
    quantized_dtype: jnp.dtype = struct.field(
        pytree_node=False
    )  # Dtype for the quantized value.
    extract_diagonal: bool = struct.field(pytree_node=False)  # In case its centered.
    shape: Any = struct.field(pytree_node=False)  # Shape of the tensor.

    @classmethod
    def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
        if isinstance(fvalue, list) and not fvalue:
            return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
        quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
            fvalue, quantized_dtype, extract_diagonal
        )
        return QuantizedValue(
            quantized,
            diagonal_fvalue,
            bucket_size,
            quantized_dtype,
            extract_diagonal,
            list(quantized.shape),
        )

    # Quantization is from Lingvo JAX optimizers.
    # We extend it for int16 quantization of PSD matrices.
    @classmethod
    def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
        """Returns quantized value and the bucket."""
        if quantized_dtype == jnp.float32:
            return fvalue, [], []
        elif quantized_dtype == jnp.bfloat16:
            return fvalue.astype(jnp.bfloat16), [], []

        float_dtype = fvalue.dtype
        if quantized_dtype == jnp.int8:
            # value -128 is not used.
            num_buckets = jnp.array(127.0, dtype=float_dtype)
        elif quantized_dtype == jnp.int16:
            # value -32768 is not used.
            num_buckets = jnp.array(32767.0, dtype=float_dtype)
        else:
            raise ValueError(f"Quantized dtype {quantized_dtype} not supported.")
        # max value is mapped to num_buckets

        if extract_diagonal and fvalue.ndim != 2:
            raise ValueError(
                f"Input array {fvalue} must be 2D to work with extract_diagonal."
            )

        diagonal_fvalue = []
        if extract_diagonal:
            diagonal_fvalue = jnp.diag(fvalue)
            # Remove the diagonal entries.
            fvalue = fvalue - jnp.diag(diagonal_fvalue)

        # TODO(rohananil): Extend this by making use of information about the blocks
        # SM3 style which will be useful for diagonal statistics
        # We first decide the scale.
        if fvalue.ndim < 1:
            raise ValueError(
                f"Input array {fvalue} must have a strictly positive number of "
                "dimensions."
            )

        max_abs = jnp.max(jnp.abs(fvalue), axis=0)
        bucket_size = max_abs / num_buckets
        bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
        # To avoid divide by 0.0
        bs_nonzero = jnp.where(
            bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded)
        )
        ratio = fvalue / bs_nonzero
        # We use rounding to remove bias.
        quantized = jnp.round(ratio)
        return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size

    def to_float(self):
        """Returns the float value."""
        if isinstance(self.quantized, list) and not self.quantized:
            return self.quantized

        if self.quantized_dtype == jnp.float32:
            return self.quantized

        if self.quantized_dtype == jnp.bfloat16:
            return self.quantized.astype(jnp.float32)

        float_dtype = self.bucket_size.dtype
        bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
        val = self.quantized.astype(float_dtype) * bucket_size
        if self.extract_diagonal:
            val += jnp.diag(self.diagonal)
        return val
